from __future__ import annotations
from collections import defaultdict, OrderedDict
from typing import Any
from collections.abc import Callable
from collections.abc import Generator
import operator
import logging

import ailment
import claripy
import networkx
from unique_log_filter import UniqueLogFilter


from angr.utils.graph import GraphUtils
from angr.utils.lazy_import import lazy_import
from angr.utils import is_pyinstaller
from angr.utils.graph import dominates, inverted_idoms
from angr.block import Block, BlockNode
from angr.errors import AngrRuntimeError
from .peephole_optimizations import InvertNegatedLogicalConjunctionsAndDisjunctions
from .structuring.structurer_nodes import (
    MultiNode,
    EmptyBlockNotice,
    SequenceNode,
    CodeNode,
    SwitchCaseNode,
    BreakNode,
    ConditionalBreakNode,
    LoopNode,
    ConditionNode,
    ContinueNode,
    CascadingConditionNode,
    IncompleteSwitchCaseNode,
)
from .graph_region import GraphRegion
from .utils import first_nonlabel_nonphi_statement, peephole_optimize_expr

if is_pyinstaller():
    # PyInstaller is not happy with lazy import
    import sympy
else:
    # outside PyInstaller, sympy is imported lazily via lazy_import()
    sympy = lazy_import("sympy")


# module-level logger; a UniqueLogFilter is attached to it
# (NOTE(review): presumably to suppress repeated identical log records — confirm)
l = logging.getLogger(__name__)
l.addFilter(UniqueLogFilter())


_UNIFIABLE_COMPARISONS = {
    "__ne__",
    "__gt__",
    "__ge__",
    "UGT",
    "UGE",
    "SGT",
    "SGE",
}


_INVERSE_OPERATIONS = {
    "__eq__": "__ne__",
    "__ne__": "__eq__",
    "__gt__": "__le__",
    "__lt__": "__ge__",
    "__ge__": "__lt__",
    "__le__": "__gt__",
    "ULT": "UGE",
    "UGE": "ULT",
    "UGT": "ULE",
    "ULE": "UGT",
    "SLT": "SGE",
    "SGE": "SLT",
    "SLE": "SGT",
    "SGT": "SLE",
}


#
# Util methods and mapping used during AIL AST to claripy AST conversion
#


def _op_with_unified_size(op, conv: Callable, operand0, operand1, ins_addr: int):
    """
    Apply a binary operation (used for shifts) to two AIL operands after making them the same bit-width.

    ``operand0`` is always converted with ``conv(..., nobool=True)``. ``operand1`` is handled as follows:

    - an ``ailment.Expr.Const`` is passed to ``op`` as its raw Python ``value``;
    - an expression of the same width as ``operand0`` is converted with ``nobool=True``;
    - a narrower expression is first widened with an unsigned ``ailment.Expr.Convert`` to ``operand0.bits``.

    :param op:          Callable taking the two converted operands (e.g. ``claripy.LShR`` or ``operator.lshift``).
    :param conv:        The AIL-to-claripy conversion callback.
    :param operand0:    The AIL expression being operated on.
    :param operand1:    The AIL operand (e.g. a shift amount).
    :param ins_addr:    Instruction address forwarded to ``conv``.
    :return:            Whatever ``op`` returns.
    :raises AssertionError: If ``operand1`` is wider than ``operand0``.
    """
    lhs = conv(operand0, nobool=True, ins_addr=ins_addr)
    if isinstance(operand1, ailment.Expr.Const):
        # amazing - we do the easy thing here
        return op(lhs, operand1.value)
    if operand1.bits == operand0.bits:
        # nobool=True keeps this branch consistent with the others: `op` needs a bit-vector, not a Bool
        return op(lhs, conv(operand1, nobool=True, ins_addr=ins_addr))
    # extension is required
    assert operand1.bits < operand0.bits
    operand1 = ailment.Expr.Convert(None, operand1.bits, operand0.bits, False, operand1)
    return op(lhs, conv(operand1, nobool=True, ins_addr=ins_addr))


def _dummy_bvs(condition, condition_mapping, name_suffix=""):
    """
    Create an opaque claripy bit-vector symbol standing in for an AIL expression without a claripy equivalent.

    The symbol is named ``ailexpr_<repr(condition)><name_suffix>`` and is ``condition.bits`` wide. It is registered
    in ``condition_mapping`` under the symbol's ``args[0]``, so that the symbol can later be mapped back to
    ``condition``.

    :return: The new claripy BVS.
    """
    symbol_name = f"ailexpr_{condition!r}{name_suffix}"
    placeholder = claripy.BVS(symbol_name, condition.bits, explicit_name=True)
    condition_mapping[placeholder.args[0]] = condition
    return placeholder


def _dummy_bools(condition, condition_mapping, name_suffix=""):
    """
    Create an opaque claripy Bool symbol standing in for an AIL expression without a claripy equivalent.

    The symbol is named ``ailexpr_<repr(condition)><name_suffix>``. It is registered in ``condition_mapping`` under
    the symbol's ``args[0]``, so that the symbol can later be mapped back to ``condition``.

    :return: The new claripy BoolS.
    """
    symbol_name = f"ailexpr_{condition!r}{name_suffix}"
    placeholder = claripy.BoolS(symbol_name, explicit_name=True)
    condition_mapping[placeholder.args[0]] = condition
    return placeholder


def _ail2claripy_binop(op: Callable) -> Callable:
    """
    Build a converter for a binary AIL operation.

    Both operands are converted with ``nobool=True`` (left first) and then combined with ``op``.
    """

    def _convert(expr, conv, _, ia):
        lhs = conv(expr.operands[0], nobool=True, ins_addr=ia)
        rhs = conv(expr.operands[1], nobool=True, ins_addr=ia)
        return op(lhs, rhs)

    return _convert


def _ail2claripy_logical(op: Callable) -> Callable:
    """Build a converter for a logical AIL operation; operands are converted without ``nobool``."""

    def _convert(expr, conv, _, ia):
        lhs = conv(expr.operands[0], ins_addr=ia)
        rhs = conv(expr.operands[1], ins_addr=ia)
        return op(lhs, rhs)

    return _convert


def _ail2claripy_unop(op: Callable) -> Callable:
    """Build a converter for a unary AIL operation applied to ``expr.operand``."""

    def _convert(expr, conv, _, ia):
        return op(conv(expr.operand, ins_addr=ia))

    return _convert


def _ail2claripy_shift(op: Callable) -> Callable:
    """Build a converter for a shift, unifying the operand sizes through _op_with_unified_size."""

    def _convert(expr, conv, _, ia):
        return _op_with_unified_size(op, conv, expr.operands[0], expr.operands[1], ia)

    return _convert


def _ail2claripy_concat(expr, conv, _, ia):
    """Convert an AIL Concat by concatenating all converted operands."""
    return claripy.Concat(*[conv(operand, ins_addr=ia) for operand in expr.operands])


def _ail2claripy_opaque_bv(expr, _, m, *_args):
    """Fallback: represent ``expr`` as an opaque bit-vector symbol registered in mapping ``m``."""
    return _dummy_bvs(expr, m)


def _ail2claripy_opaque_bool(expr, _, m, *_args):
    """Fallback: represent ``expr`` as an opaque Bool symbol registered in mapping ``m``."""
    return _dummy_bools(expr, m)


# AIL operator name -> converter(expr, conv, condition_mapping, ins_addr) producing a claripy AST.
# Python operators on claripy ASTs are spelled through the `operator` module (operator.le(a, b) is `a <= b`, etc.).
_ail2claripy_op_mapping = {
    "LogicalAnd": _ail2claripy_logical(claripy.And),
    "LogicalOr": _ail2claripy_logical(claripy.Or),
    "CmpEQ": _ail2claripy_binop(operator.eq),
    "CmpNE": _ail2claripy_binop(operator.ne),
    "CmpLE": _ail2claripy_binop(operator.le),
    "CmpLE (signed)": _ail2claripy_binop(claripy.SLE),
    "CmpLT": _ail2claripy_binop(operator.lt),
    "CmpLT (signed)": _ail2claripy_binop(claripy.SLT),
    "CmpGE": _ail2claripy_binop(operator.ge),
    "CmpGE (signed)": _ail2claripy_binop(claripy.SGE),
    "CmpGT": _ail2claripy_binop(operator.gt),
    "CmpGT (signed)": _ail2claripy_binop(claripy.SGT),
    "CasCmpEQ": _ail2claripy_binop(operator.eq),
    "CasCmpNE": _ail2claripy_binop(operator.ne),
    "CasCmpLE": _ail2claripy_binop(operator.le),
    "CasCmpLE (signed)": _ail2claripy_binop(claripy.SLE),
    "CasCmpLT": _ail2claripy_binop(operator.lt),
    "CasCmpLT (signed)": _ail2claripy_binop(claripy.SLT),
    "CasCmpGE": _ail2claripy_binop(operator.ge),
    "CasCmpGE (signed)": _ail2claripy_binop(claripy.SGE),
    "CasCmpGT": _ail2claripy_binop(operator.gt),
    "CasCmpGT (signed)": _ail2claripy_binop(claripy.SGT),
    "Add": _ail2claripy_binop(operator.add),
    "Sub": _ail2claripy_binop(operator.sub),
    "Mul": _ail2claripy_binop(operator.mul),
    "Div": _ail2claripy_binop(operator.truediv),
    "Mod": _ail2claripy_binop(operator.mod),
    "Not": _ail2claripy_unop(claripy.Not),
    "Neg": _ail2claripy_unop(operator.neg),
    "BitwiseNeg": _ail2claripy_unop(operator.invert),
    "Xor": _ail2claripy_binop(operator.xor),
    "And": _ail2claripy_binop(operator.and_),
    "Or": _ail2claripy_binop(operator.or_),
    "Shr": _ail2claripy_shift(claripy.LShR),
    "Shl": _ail2claripy_shift(operator.lshift),
    "Sar": _ail2claripy_shift(operator.rshift),
    "Concat": _ail2claripy_concat,
    # There are no corresponding claripy operations for the following operations
    "CmpF": _ail2claripy_opaque_bv,
    "Mull": _ail2claripy_opaque_bv,
    "Mull (signed)": _ail2claripy_opaque_bv,
    "Reinterpret": _ail2claripy_opaque_bv,
    "Rol": _ail2claripy_opaque_bv,
    "Ror": _ail2claripy_opaque_bv,
    "LogicalXor": _ail2claripy_opaque_bv,
    "Carry": _ail2claripy_opaque_bv,
    "SCarry": _ail2claripy_opaque_bv,
    "SBorrow": _ail2claripy_opaque_bv,
    "ExpCmpNE": _ail2claripy_opaque_bool,
    "CmpORD": _ail2claripy_opaque_bv,  # in case CmpORDRewriter fails
    "CmpEQV": _ail2claripy_opaque_bv,
    "GetMSBs": _ail2claripy_opaque_bv,
    "ShlNV": _ail2claripy_opaque_bv,
    "ShrNV": _ail2claripy_opaque_bv,
    "InterleaveLOV": _ail2claripy_opaque_bv,
    "InterleaveHIV": _ail2claripy_opaque_bv,
    # catch-all
    "_DUMMY_": _ail2claripy_opaque_bv,
}

#
# The ConditionProcessor class
#


class ConditionProcessor:
    """
    Convert between claripy AST and AIL expressions. Also calculates reaching conditions of all nodes on a graph.
    """

    def __init__(self, arch, condition_mapping=None):
        """
        :param arch:                The target architecture; ``arch.bits`` sizes the jump-target comparisons built by
                                    ``recover_reaching_conditions``.
        :param condition_mapping:   Optional mapping used to translate claripy placeholders (or their names) back to
                                    AIL expressions. When given, that very dict object is used (not copied);
                                    otherwise a fresh dict is created.
        """
        self.arch = arch
        self._condition_mapping: dict[str, Any] = condition_mapping if condition_mapping is not None else {}
        # switch-head address -> set of claripy conditions created for that switch's jump-table entries
        self.jump_table_conds: dict[int, set] = defaultdict(set)
        self.edge_conditions = {}
        self.reaching_conditions = {}
        self.guarding_conditions = {}
        self._ast2annotations = {}

        # peephole optimizations run on every expression produced by convert_claripy_bool_ast()
        self._peephole_expr_optimizations = [InvertNegatedLogicalConjunctionsAndDisjunctions(None, None, None)]

    def clear(self):
        self._condition_mapping = {}
        self.jump_table_conds = defaultdict(set)
        self.reaching_conditions = {}
        self.guarding_conditions = {}
        self._ast2annotations = {}

    def recover_edge_condition(self, graph: networkx.DiGraph, src, dst):
        edge = src, dst
        edge_data = graph.get_edge_data(*edge)
        edge_type = edge_data.get("type", "transition") if edge_data is not None else "transition"
        try:
            predicate = self._extract_predicate(src, dst, edge_type)
        except EmptyBlockNotice:
            # catch empty block notice - although this should not really happen
            predicate = claripy.true()
        return predicate

    def recover_edge_conditions(self, region, graph=None) -> dict:
        edge_conditions = {}
        # traverse the graph to recover the condition for each edge
        graph = graph or region.graph
        for src in graph.nodes():
            nodes = list(graph[src])
            if len(nodes) >= 1:
                for dst in nodes:
                    predicate = self.recover_edge_condition(graph, src, dst)
                    edge_conditions[(src, dst)] = predicate

        self.edge_conditions = edge_conditions

    def recover_reaching_conditions(
        self,
        region,
        graph=None,
        with_successors=False,
        case_entry_to_switch_head: dict[int, int] | None = None,
        simplify_conditions: bool = True,
    ):
        """
        Recover the reaching conditions for each block in an acyclic graph. Note that we assume the graph that's passed
        in is acyclic.

        Results are stored on the instance:

        - ``self.edge_conditions``: recovered via ``recover_edge_conditions``;
        - ``self.reaching_conditions``: node -> claripy Bool;
        - ``self.guarding_conditions``: node -> claripy Bool, only for nodes with exactly two predecessors and at least
          one diverging edge condition;
        - ``self.jump_table_conds``: switch-head address -> conditions created for jump-table entries (accumulated).

        :param region:          Region to analyze. Its ``graph`` / ``graph_with_successors`` and ``head`` are used
                                when ``graph`` is not provided.
        :param graph:           Explicit graph to analyze instead. Its head is the first node with in-degree 0.
        :param with_successors: When ``graph`` is not given, use ``region.graph_with_successors`` (if not None).
        :param case_entry_to_switch_head: Maps switch-case entry addresses to their switch-head address. Such entries
                                get the reaching condition ``jump_target_var == entry_addr``, and edges crossing
                                between cases are removed first.
        :param simplify_conditions: Run ``simplify_condition`` on every reaching condition that is computed from
                                predecessors (or is the head's).
        """

        def _strictly_postdominates(inv_idoms, node_a, node_b):
            """
            Does node A strictly post-dominate node B on the graph?
            """
            # `inv_idoms` come from inverted_idoms(), so dominance there corresponds to post-dominance here
            return dominates(inv_idoms, node_a, node_b)

        self.recover_edge_conditions(region, graph=graph)
        edge_conditions = self.edge_conditions

        if graph:
            _g = graph
            head = next(node for node in graph.nodes if graph.in_degree(node) == 0)
        else:
            if with_successors and region.graph_with_successors is not None:
                _g = region.graph_with_successors
            else:
                _g = region.graph
            head = region.head

        # special handling for jump table entries - do not allow crossing between cases
        if case_entry_to_switch_head:
            _g = self._remove_crossing_edges_between_cases(_g, case_entry_to_switch_head)

        inverted_graph, idoms = inverted_idoms(_g)

        reaching_conditions = {}
        # recover the reaching condition for each node
        sorted_nodes = GraphUtils.quasi_topological_sort_nodes(_g)
        for node in sorted_nodes:
            # create special conditions for all nodes that are jump table entries
            if case_entry_to_switch_head and node.addr in case_entry_to_switch_head:
                jump_target_var = self.create_jump_target_var(case_entry_to_switch_head[node.addr])
                cond = jump_target_var == claripy.BVV(node.addr, self.arch.bits)
                reaching_conditions[node] = cond
                self.jump_table_conds[case_entry_to_switch_head[node.addr]].add(cond)
                continue

            preds = _g.predecessors(node)
            reaching_condition = None

            if node is head:
                # the head is always reachable
                reaching_condition = claripy.true()
            elif idoms is not None and _strictly_postdominates(idoms, node, head):
                # the node that post dominates the head is always reachable
                reaching_conditions[node] = claripy.true()
            else:
                # OR over all predecessors of (predecessor's reaching condition AND edge condition)
                for pred in preds:
                    edge = (pred, node)
                    pred_condition = reaching_conditions.get(pred, claripy.true())
                    edge_condition = edge_conditions.get(edge, claripy.true())

                    if reaching_condition is None:
                        reaching_condition = claripy.And(pred_condition, edge_condition)
                    else:
                        reaching_condition = claripy.Or(claripy.And(pred_condition, edge_condition), reaching_condition)

            if reaching_condition is not None:
                reaching_conditions[node] = (
                    self.simplify_condition(reaching_condition) if simplify_conditions else reaching_condition
                )

        # My hypothesis: for nodes where two paths come together *and* those that cannot be further structured into
        # another if-else construct (we take the short-cut by testing if the operator is an "Or" after running our
        # condition simplifiers previously), we are better off using their "guarding conditions" instead of their
        # reaching conditions for if-else. see my super long chatlog with rhelmot on 5/14/2021.
        guarding_conditions = {}
        for the_node in sorted_nodes:
            preds = list(_g.predecessors(the_node))
            if len(preds) != 2:
                continue
            # generate a graph slice that goes from the region head to this node
            slice_nodes = list(networkx.dfs_tree(inverted_graph, the_node))
            subgraph = networkx.subgraph(_g, slice_nodes)
            # figure out which paths cause the divergence from this node
            nodes_do_not_reach_the_node = set()
            for node_ in subgraph:
                if node_ is the_node:
                    continue
                for succ in _g.successors(node_):
                    if not networkx.has_path(_g, succ, the_node):
                        nodes_do_not_reach_the_node.add(succ)

            diverging_conditions = []

            for node_ in nodes_do_not_reach_the_node:
                preds_ = list(_g.predecessors(node_))
                for pred_ in preds_:
                    if pred_ in nodes_do_not_reach_the_node:
                        continue
                    # this predecessor is the diverging node!
                    edge_ = pred_, node_
                    edge_condition = edge_conditions.get(edge_, None)
                    if edge_condition is not None:
                        diverging_conditions.append(edge_condition)

            if diverging_conditions:
                # the negation of the union of diverging conditions is the guarding condition for this node
                # NOTE(review): Not(Or(conds)) would be And of the negations; the code below builds the Or of the
                # negations (i.e. Not(And(conds))). Confirm which one is intended before changing either side.
                cond = claripy.Or(*map(claripy.Not, diverging_conditions))  # pylint:disable=bad-builtin
                guarding_conditions[the_node] = cond

        self.reaching_conditions = reaching_conditions
        self.guarding_conditions = guarding_conditions

    def remove_claripy_bool_asts(self, node, memo=None):
        """
        Rebuild a structured node tree, converting every claripy Bool AST it carries into an AIL expression.

        Covered are reaching conditions, branch/loop/break conditions and switch expressions. Each is converted with
        ``convert_claripy_bool_ast``.

        :param node:    Root of the (sub)tree. Node types without special handling, including ``None``, are
                        returned unchanged.
        :param memo:    Conversion cache shared across the whole traversal; a new dict is created when None.
        :return:        A rebuilt node of the same kind (``LoopNode`` is rebuilt from ``node.copy()``), or ``node``.
        """
        if memo is None:
            memo = {}

        def _rebuild(child):
            return self.remove_claripy_bool_asts(child, memo=memo)

        def _convert(cond):
            return self.convert_claripy_bool_ast(cond, memo=memo)

        def _convert_optional(cond):
            return None if cond is None else _convert(cond)

        if isinstance(node, SequenceNode):
            return SequenceNode(node.addr, [_rebuild(child) for child in node.nodes])

        if isinstance(node, MultiNode):
            return MultiNode(nodes=[_rebuild(child) for child in node.nodes])

        if isinstance(node, CodeNode):
            return CodeNode(_rebuild(node.node), _convert_optional(node.reaching_condition))

        if isinstance(node, ConditionalBreakNode):
            return ConditionalBreakNode(node.addr, _convert(node.condition), node.target)

        if isinstance(node, ConditionNode):
            return ConditionNode(
                node.addr,
                _convert_optional(node.reaching_condition),
                _convert(node.condition),
                _rebuild(node.true_node),
                _rebuild(node.false_node),
            )

        if isinstance(node, CascadingConditionNode):
            rebuilt_pairs = [(_convert(cond), _rebuild(child)) for cond, child in node.condition_and_nodes]
            rebuilt_else = None if node.else_node is None else _rebuild(node.else_node)
            return CascadingConditionNode(node.addr, rebuilt_pairs, else_node=rebuilt_else)

        if isinstance(node, LoopNode):
            loop = node.copy()
            loop.condition = _convert_optional(node.condition)
            loop.sequence_node = _rebuild(node.sequence_node)
            return loop

        if isinstance(node, SwitchCaseNode):
            return SwitchCaseNode(
                _convert(node.switch_expr),
                OrderedDict((idx, _rebuild(case_node)) for idx, case_node in node.cases.items()),
                _rebuild(node.default_node),
                addr=node.addr,
            )

        if isinstance(node, IncompleteSwitchCaseNode):
            return IncompleteSwitchCaseNode(node.addr, _rebuild(node.head), [_rebuild(case) for case in node.cases])

        return node

    @classmethod
    def get_last_statement(cls, block):
        """
        This is the buggy version of get_last_statements, because, you know, there can always be more than one last
        statement due to the existence of branching statements (like, If-then-else). All methods using
        get_last_statement() should switch to get_last_statements() and properly handle multiple last statements.

        Dispatch is on the exact type of ``block``. ``BreakNode``, ``ContinueNode``, ``ConditionalBreakNode``,
        switch-case nodes and ``GraphRegion`` yield None.

        :raises EmptyBlockNotice:       When the relevant part of ``block`` holds no statements.
        :raises NotImplementedError:    For ``Block`` / ``BlockNode`` and any unsupported type.
        """
        block_type = type(block)

        if block_type is SequenceNode:
            if not block.nodes:
                raise EmptyBlockNotice
            return cls.get_last_statement(block.nodes[-1])

        if block_type is CodeNode:
            return cls.get_last_statement(block.node)

        if block_type is ailment.Block:
            if not block.statements:
                raise EmptyBlockNotice
            return block.statements[-1]

        if block_type is Block or block_type is BlockNode:
            raise NotImplementedError

        if block_type is MultiNode:
            # walk backwards until some child yields a statement
            for child in reversed(block.nodes):
                try:
                    return cls.get_last_statement(child)
                except EmptyBlockNotice:
                    pass
            raise EmptyBlockNotice

        if block_type is LoopNode:
            return cls.get_last_statement(block.sequence_node)

        if block_type is ConditionNode:
            last = None
            if block.true_node:
                try:
                    last = cls.get_last_statement(block.true_node)
                except EmptyBlockNotice:
                    last = None
            if last is None and block.false_node:
                last = cls.get_last_statement(block.false_node)
            return last

        if block_type is CascadingConditionNode:
            if block.else_node is not None:
                return cls.get_last_statement(block.else_node)
            last = None
            for _, child in reversed(block.condition_and_nodes):
                last = cls.get_last_statement(child)
                if last is not None:
                    break
            return last

        if block_type in (
            ConditionalBreakNode,
            BreakNode,
            ContinueNode,
            SwitchCaseNode,
            IncompleteSwitchCaseNode,
            # GraphRegion: normally this should not happen. however, we have test cases that trigger this case.
            GraphRegion,
        ):
            return None

        raise NotImplementedError

    @classmethod
    def get_last_statements(cls, block) -> list[ailment.Stmt.Statement | None]:
        """
        Collect every statement that may be the last one executed when control leaves ``block``.

        A ``None`` entry marks a path that leaves ``block`` without a last statement of its own. Examples are a
        missing *or empty* branch of a ``ConditionNode`` or a ``CascadingConditionNode`` without (or with an empty)
        else node. Previously, an empty true/else branch was silently dropped. An empty false branch raised
        ``EmptyBlockNotice``, which made an enclosing ``SequenceNode`` skip the whole condition node.

        ``BreakNode``, ``ContinueNode`` and ``ConditionalBreakNode`` are returned as themselves.

        :param block:   A structured node, an ``ailment.Block``, or a ``GraphRegion`` (the latter yields ``[]``).
        :return:        List of last statements; may contain None.
        :raises EmptyBlockNotice:       When ``block`` (or a mandatory child) contains no statements.
        :raises NotImplementedError:    For ``Block`` / ``BlockNode`` and unsupported node types.
        """
        if type(block) is SequenceNode:
            for last_node in reversed(block.nodes):
                try:
                    return cls.get_last_statements(last_node)
                except EmptyBlockNotice:
                    # the node is empty. try the next one
                    continue

            raise EmptyBlockNotice

        if type(block) is CodeNode:
            return cls.get_last_statements(block.node)
        if type(block) is ailment.Block:
            if not block.statements:
                raise EmptyBlockNotice
            return [block.statements[-1]]
        if type(block) is Block:
            raise NotImplementedError
        if type(block) is BlockNode:
            raise NotImplementedError
        if type(block) is MultiNode:
            # get the last node
            for the_block in reversed(block.nodes):
                try:
                    return cls.get_last_statements(the_block)
                except EmptyBlockNotice:
                    continue
            raise EmptyBlockNotice
        if type(block) is LoopNode:
            if block.sequence_node is None:
                raise EmptyBlockNotice
            return cls.get_last_statements(block.sequence_node)
        if type(block) is ConditionalBreakNode:
            return [block]
        if type(block) is ConditionNode:
            s = []
            for branch in (block.true_node, block.false_node):
                if not branch:
                    # missing branch: control falls through
                    s.append(None)
                    continue
                try:
                    s.extend(cls.get_last_statements(branch))
                except EmptyBlockNotice:
                    # an empty branch falls through just like a missing one
                    s.append(None)
            return s
        if type(block) is CascadingConditionNode:
            s = []
            if block.else_node is None:
                s.append(None)
            else:
                try:
                    s.extend(cls.get_last_statements(block.else_node))
                except EmptyBlockNotice:
                    # an empty else branch falls through just like a missing one
                    s.append(None)
            for _, node in block.condition_and_nodes:
                last_stmts = cls.get_last_statements(node)
                s.extend(last_stmts)
            return s
        if type(block) is BreakNode:
            return [block]
        if type(block) is ContinueNode:
            return [block]
        if type(block) is SwitchCaseNode:
            s = []
            for case in block.cases.values():
                s.extend(cls.get_last_statements(case))
            if block.default_node is not None:
                s.extend(cls.get_last_statements(block.default_node))
            else:
                s.append(None)
            return s
        if type(block) is IncompleteSwitchCaseNode:
            s = []
            for case in block.cases:
                s.extend(cls.get_last_statements(case))
            return s
        if type(block) is GraphRegion:
            # normally this should not happen. however, we have test cases that trigger this case.
            return []

        raise NotImplementedError

    #
    # Path predicate
    #

    # Starting value of the counter that _extract_predicate increments to build a unique placeholder condition for
    # each exception edge. The first `self.EXC_COUNTER += 1` on an instance creates an instance attribute that
    # shadows this class attribute.
    EXC_COUNTER = 1000

    def _extract_predicate(self, src_block, dst_block, edge_type) -> claripy.ast.Bool:
        """
        Compute the claripy condition under which control flows along the edge ``src_block -> dst_block``.

        :param src_block:   Source node of the edge.
        :param dst_block:   Destination node; its ``addr`` is compared against jump targets.
        :param edge_type:   The edge's ``type`` attribute. ``"exception"`` edges get a fresh, unique condition.
        :return:            A claripy Bool. ``claripy.true()`` is returned for unconditional edges and whenever no
                            condition can be determined. This includes a conditional jump whose two targets are both
                            ``dst_block``.
        :raises EmptyBlockNotice: Propagated from ``get_last_statement`` when ``src_block`` holds no statements.
        """
        if edge_type == "exception":
            # TODO: THIS IS ABSOLUTELY A HACK. AT THIS MOMENT YOU SHOULD NOT ATTEMPT TO MAKE SENSE OF EXCEPTION EDGES.
            # every exception edge gets its own `reg == const` condition derived from a freshly bumped counter
            self.EXC_COUNTER += 1
            return self.claripy_ast_from_ail_condition(
                ailment.Expr.BinaryOp(
                    None,
                    "CmpEQ",
                    (
                        ailment.Expr.Register(0x400000 + self.EXC_COUNTER, None, self.EXC_COUNTER, 64),
                        ailment.Expr.Const(None, None, self.EXC_COUNTER, 64),
                    ),
                    False,
                ),
                ins_addr=dst_block.addr,
            )

        if type(src_block) is ConditionalBreakNode:
            # at this point ConditionalBreakNode stores a claripy AST
            bool_var = src_block.condition
            if src_block.target == dst_block.addr:
                return bool_var
            return claripy.Not(bool_var)

        if type(src_block) is GraphRegion:
            return claripy.true()

        # sometimes the last statement is the conditional jump. sometimes it's the first statement of the block
        last_stmt = None
        if isinstance(src_block, ailment.Block) and src_block.statements:
            first_stmt = first_nonlabel_nonphi_statement(src_block)
            if isinstance(first_stmt, ailment.Stmt.ConditionalJump):
                last_stmt = first_stmt
        if last_stmt is None:
            last_stmt = self.get_last_statement(src_block)

        if last_stmt is None:
            return claripy.true()
        if type(last_stmt) is ailment.Stmt.Jump:
            if isinstance(last_stmt.target, ailment.Expr.Const):
                return claripy.true()
            # indirect jump: the edge is taken when the computed target equals the destination address
            target_ast = self.claripy_ast_from_ail_condition(last_stmt.target, ins_addr=last_stmt.ins_addr)
            return target_ast == dst_block.addr
        if type(last_stmt) is ailment.Stmt.ConditionalJump:
            true_target = last_stmt.true_target
            false_target = last_stmt.false_target
            if (
                isinstance(true_target, ailment.Expr.Const)
                and isinstance(false_target, ailment.Expr.Const)
                and true_target.value == dst_block.addr
                and false_target.value == dst_block.addr
            ):
                # degenerate conditional jump: both targets lead to dst_block, so the edge is unconditional
                return claripy.true()
            bool_var = self.claripy_ast_from_ail_condition(last_stmt.condition, ins_addr=last_stmt.ins_addr)
            if isinstance(true_target, ailment.Expr.Const) and true_target.value == dst_block.addr:
                return bool_var
            return claripy.Not(bool_var)

        return claripy.true()

    #
    # Expression conversion
    #

    def _convert_extract(self, hi, lo, expr, tags, memo=None):
        # ailment does not support Extract. We translate Extract to Convert and shift.
        if lo == 0:
            return ailment.Expr.Convert(
                None,
                expr.size(),
                hi + 1,
                False,
                self.convert_claripy_bool_ast(expr, memo=memo),
                **tags,
            )

        raise NotImplementedError("This case will be implemented once encountered.")

    def convert_claripy_bool_ast(self, cond, memo=None):
        """
        Convert recovered reaching conditions from claripy ASTs to ailment Expressions.

        The result is passed through the peephole expression optimizations configured in ``__init__``.

        :param cond:    A claripy AST. An ailment Expression is accepted as well; it is not converted, only
                        peephole-optimized.
        :param memo:    Optional cache of already-converted claripy ASTs, keyed by the AST's ``_hash``. Share one dict
                        across related calls to reuse results.
        :return:        The resulting ailment Expression (the previous docstring wrongly said None).
        """
        if memo is None:
            memo = {}

        if isinstance(cond, ailment.Expr.Expression):
            # AIL expressions are returned unchanged by convert_claripy_bool_ast_core, so there is nothing to memoize.
            # Keying the memo on their `_hash` is also unsafe. NOTE(review): ailment appears to compute `_hash` lazily
            # (None until hash() is called), which would make unrelated AIL expressions share one memo entry.
            optimized = peephole_optimize_expr(cond, self._peephole_expr_optimizations)
            return cond if optimized is None else optimized

        if cond._hash in memo:
            return memo[cond._hash]
        r = self.convert_claripy_bool_ast_core(cond, memo)
        optimized_r = peephole_optimize_expr(r, self._peephole_expr_optimizations)
        r = r if optimized_r is None else optimized_r
        memo[cond._hash] = r
        return r

    def convert_claripy_bool_ast_core(self, cond, memo):
        """
        Convert one claripy AST node (and, through ``convert_claripy_bool_ast``, its children) into an AIL expression.

        AIL expressions are passed through unchanged. Symbolic variables whose AST, or whose name, is a key of
        ``self._condition_mapping`` are mapped back to the recorded AIL expression. Every other node is translated
        operator by operator. Tags recorded in ``self._ast2annotations`` for the node, or for its negation, are
        attached to the resulting AIL expression.

        :param cond:    A claripy AST, or an AIL expression.
        :param memo:    Memoization dict passed along to ``convert_claripy_bool_ast`` for child conversions.
        :return:        The corresponding AIL expression.
        :raises NotImplementedError: If the claripy operator of ``cond`` has no translation here.
        """
        if isinstance(cond, ailment.Expr.Expression):
            return cond

        if cond.op in {"BoolS", "BoolV"} and claripy.is_true(cond):
            return ailment.Expr.Const(None, None, True, 1)
        if cond in self._condition_mapping:
            return self._condition_mapping[cond]
        if cond.op in {"BVS", "BoolS"} and cond.args[0] in self._condition_mapping:
            # symbolic variables created during AIL -> claripy conversion are recorded under their name (args[0])
            return self._condition_mapping[cond.args[0]]

        def _binary_op_reduce(op, args, tags, signed=False):
            # left-fold all operands into nested binary AIL operations
            r = None
            for arg in args:
                if r is None:
                    r = self.convert_claripy_bool_ast(arg, memo=memo)
                else:
                    r = ailment.Expr.BinaryOp(
                        None, op, (r, self.convert_claripy_bool_ast(arg, memo=memo)), signed, **tags
                    )
            return r

        def _unary_op_reduce(op, arg, tags):
            r = self.convert_claripy_bool_ast(arg, memo=memo)
            # TODO: Keep track of tags
            return ailment.Expr.UnaryOp(None, op, r, **tags)

        _mapping = {
            "Not": lambda cond_, tags: _unary_op_reduce("Not", cond_.args[0], tags),
            # __neg__ is arithmetic (two's complement) negation of a bit-vector, not a logical negation
            "__neg__": lambda cond_, tags: _unary_op_reduce("Neg", cond_.args[0], tags),
            "__invert__": lambda cond_, tags: _unary_op_reduce("BitwiseNeg", cond_.args[0], tags),
            "And": lambda cond_, tags: _binary_op_reduce("LogicalAnd", cond_.args, tags),
            "Or": lambda cond_, tags: _binary_op_reduce("LogicalOr", cond_.args, tags),
            "__le__": lambda cond_, tags: _binary_op_reduce("CmpLE", cond_.args, tags, signed=True),
            "SLE": lambda cond_, tags: _binary_op_reduce("CmpLE", cond_.args, tags, signed=True),
            "__lt__": lambda cond_, tags: _binary_op_reduce("CmpLT", cond_.args, tags, signed=True),
            "SLT": lambda cond_, tags: _binary_op_reduce("CmpLT", cond_.args, tags, signed=True),
            "UGT": lambda cond_, tags: _binary_op_reduce("CmpGT", cond_.args, tags),
            "UGE": lambda cond_, tags: _binary_op_reduce("CmpGE", cond_.args, tags),
            "__gt__": lambda cond_, tags: _binary_op_reduce("CmpGT", cond_.args, tags, signed=True),
            "__ge__": lambda cond_, tags: _binary_op_reduce("CmpGE", cond_.args, tags, signed=True),
            "SGT": lambda cond_, tags: _binary_op_reduce("CmpGT", cond_.args, tags, signed=True),
            "SGE": lambda cond_, tags: _binary_op_reduce("CmpGE", cond_.args, tags, signed=True),
            "ULT": lambda cond_, tags: _binary_op_reduce("CmpLT", cond_.args, tags),
            "ULE": lambda cond_, tags: _binary_op_reduce("CmpLE", cond_.args, tags),
            "__eq__": lambda cond_, tags: _binary_op_reduce("CmpEQ", cond_.args, tags),
            "__ne__": lambda cond_, tags: _binary_op_reduce("CmpNE", cond_.args, tags),
            "__add__": lambda cond_, tags: _binary_op_reduce("Add", cond_.args, tags, signed=False),
            "__sub__": lambda cond_, tags: _binary_op_reduce("Sub", cond_.args, tags),
            "__mul__": lambda cond_, tags: _binary_op_reduce("Mul", cond_.args, tags),
            "__xor__": lambda cond_, tags: _binary_op_reduce("Xor", cond_.args, tags),
            "__or__": lambda cond_, tags: _binary_op_reduce("Or", cond_.args, tags, signed=False),
            "__and__": lambda cond_, tags: _binary_op_reduce("And", cond_.args, tags),
            "__lshift__": lambda cond_, tags: _binary_op_reduce("Shl", cond_.args, tags),
            "__rshift__": lambda cond_, tags: _binary_op_reduce("Sar", cond_.args, tags),
            "__floordiv__": lambda cond_, tags: _binary_op_reduce("Div", cond_.args, tags),
            "__mod__": lambda cond_, tags: _binary_op_reduce("Mod", cond_.args, tags),
            "LShR": lambda cond_, tags: _binary_op_reduce("Shr", cond_.args, tags),
            "BVV": lambda cond_, tags: ailment.Expr.Const(None, None, cond_.args[0], cond_.size(), **tags),
            "BoolV": lambda cond_, tags: (
                ailment.Expr.Const(None, None, True, 1, **tags)
                if cond_.args[0] is True
                else ailment.Expr.Const(None, None, False, 1, **tags)
            ),
            "Extract": lambda cond_, tags: self._convert_extract(*cond_.args, tags, memo=memo),
            # zero-extension is expressed as a concatenation with a zero constant of the extension width
            "ZeroExt": lambda cond_, tags: _binary_op_reduce(
                "Concat", [claripy.BVV(0, cond_.args[0]), cond_.args[1]], tags
            ),
            "Concat": lambda cond_, tags: _binary_op_reduce("Concat", cond_.args, tags),
        }

        if cond.op in _mapping:
            # recover the tags of the original AIL expression; the AST may have been negated along the way
            if cond in self._ast2annotations:
                cond_tags = self._ast2annotations.get(cond)
            elif claripy.Not(cond) in self._ast2annotations:
                cond_tags = self._ast2annotations.get(claripy.Not(cond))
            else:
                cond_tags = {}
            return _mapping[cond.op](cond, cond_tags)
        raise NotImplementedError(
            f"Condition variable {cond} has an unsupported operator {cond.op}. Consider implementing."
        )

    def claripy_ast_from_ail_condition(
        self, condition, nobool: bool = False, *, ins_addr: int = 0
    ) -> claripy.ast.Bool | claripy.ast.Bits:
        """
        Convert an AIL expression into a claripy AST, unpacking it all the way down to its leaves.

        Several kinds of AIL expressions are replaced by symbolic variables instead of being translated
        structurally: VEX ccalls, base pointer offsets, ITEs, calls, loads, registers, virtual variables,
        conversions, left-over temporaries, and multi-statement expressions. For the ones created here, the
        variable name is recorded in ``self._condition_mapping`` together with the original AIL expression.
        ``_dummy_bvs`` receives the same mapping and presumably does the same. Any other expression goes through
        ``_ail2claripy_op_mapping``, and its tags are stored in ``self._ast2annotations`` under the resulting AST.

        :param condition:   The AIL expression to convert. claripy Bits/Bool ASTs are returned as-is.
        :param nobool:      When set, a Bool produced by the operator mapping is replaced by a 1-bit BVS, and an
                            unsupported 1-bit expression becomes a 1-bit BVS instead of a BoolS.
        :param ins_addr:    Instruction address. It is embedded into the names of variables created for loads,
                            registers and virtual variables, and forwarded to the operator mapping.
        :return:            The claripy AST.
        """
        # already a claripy AST: nothing to do
        if isinstance(
            condition, (claripy.ast.Bits, claripy.ast.Bool)
        ):  # pylint:disable=isinstance-second-argument-not-valid-type
            return condition

        cond_map = self._condition_mapping

        if isinstance(
            condition,
            (ailment.Expr.VEXCCallExpression, ailment.Expr.BasePointerOffset, ailment.Expr.ITE),
        ):
            return _dummy_bvs(condition, cond_map)

        if isinstance(condition, ailment.Stmt.Call):
            call_site = condition.tags.get("ins_addr", 0)
            return _dummy_bvs(condition, cond_map, name_suffix=hex(call_site))

        if isinstance(condition, (ailment.Expr.Load, ailment.Expr.Register, ailment.Expr.VirtualVariable)):
            # name the symbol after the associated variable when there is one, otherwise after the expression index
            ident = condition.variable.ident if condition.variable is not None else condition.idx
            sym = claripy.BVS(f"ailexpr_{condition!r}-{ident}-{ins_addr:x}", condition.bits, explicit_name=True)
            cond_map[sym.args[0]] = condition
            return sym

        if isinstance(condition, ailment.Expr.Convert):
            operand_ast = self.claripy_ast_from_ail_condition(condition.operands[0], ins_addr=ins_addr)
            conv_desc = (condition.from_bits, condition.to_bits, hash(operand_ast))
            if condition.to_bits == 1:
                # a conversion down to a single bit is modeled as a boolean symbol
                sym = claripy.BoolS("ailcond_Conv(%d->%d, %d)" % conv_desc, explicit_name=True)
            else:
                sym = claripy.BVS("ailexpr_Conv(%d->%d, %d)" % conv_desc, condition.to_bits, explicit_name=True)
            cond_map[sym.args[0]] = condition
            return sym

        if isinstance(condition, ailment.Expr.Const):
            value = condition.value
            if value is True or value is False:
                return claripy.BoolV(value)
            bv = claripy.BVV(value, condition.bits)
            if bv.size() == 1:
                # single-bit constants become boolean constants
                return claripy.true() if bv.concrete_value == 1 else claripy.false()
            return bv

        if isinstance(condition, ailment.Expr.Tmp):
            l.warning("Left-over ailment.Tmp variable %s.", condition)
            tmp_name = "ailtmp_%d" % condition.tmp_idx
            if condition.bits == 1:
                sym = claripy.BoolS(tmp_name, explicit_name=True)
            else:
                sym = claripy.BVS(tmp_name, condition.bits, explicit_name=True)
            cond_map[sym.args[0]] = condition
            return sym

        if isinstance(condition, ailment.Expr.MultiStatementExpression):
            # opaque to us: represent it by a symbol and remember the mapping
            mstmt_name = "mstmtexpr_%d" % hash(condition)
            if condition.bits == 1:
                sym = claripy.BoolS(mstmt_name, explicit_name=True)
            else:
                sym = claripy.BVS(mstmt_name, condition.bits, explicit_name=True)
            cond_map[sym.args[0]] = condition
            return sym

        # look up a converter by the verbose operator first, then by the plain operator, then use the catch-all
        converter = _ail2claripy_op_mapping.get(condition.verbose_op, None)
        if converter is None:
            converter = _ail2claripy_op_mapping.get(condition.op, None)
        if converter is None:
            l.debug(
                "Unsupported AIL expression operation %s (or verbose: %s). Fall back to the default catch-all dummy "
                "option. Consider implementing.",
                condition.op,
                condition.verbose_op,
            )
            converter = _ail2claripy_op_mapping["_DUMMY_"]
        result = converter(condition, self.claripy_ast_from_ail_condition, cond_map, ins_addr)

        if nobool and isinstance(result, claripy.ast.Bool):
            result = claripy.BVS(f"ailexpr_from_bool_{result!r}", 1, explicit_name=True)
            cond_map[result.args[0]] = condition
        elif result is NotImplemented:
            if condition.bits == 1 and not nobool:
                result = claripy.BoolS(f"ailexpr_{condition!r}", explicit_name=True)
            else:
                result = claripy.BVS(f"ailexpr_{condition!r}", condition.bits, explicit_name=True)
            cond_map[result.args[0]] = condition

        # remember the tags so that converting back to AIL can restore them
        self._ast2annotations[result] = condition.tags
        return result

    #
    # Expression simplification
    #

    @staticmethod
    def claripy_ast_to_sympy_expr(ast, memo=None):
        """
        Convert a claripy boolean AST into a sympy boolean expression over opaque symbols.

        And, Or and Not are translated structurally. A comparison listed in ``_UNIFIABLE_COMPARISONS`` is rewritten
        as the negation of its inverse comparison. For example, ``a != b`` becomes ``Not(a == b)``, so a comparison
        and its inverse share one symbol. Every other node becomes a sympy Symbol named after the node's hash.

        :param ast:     The claripy AST to convert.
        :param memo:    Optional dict that receives a mapping from each created sympy Symbol to the claripy AST it
                        stands for. Pass the same dict to ``sympy_expr_to_claripy_ast`` to convert back.
        :return:        The sympy expression.
        """
        if ast.op == "And":
            return sympy.And(*(ConditionProcessor.claripy_ast_to_sympy_expr(arg, memo=memo) for arg in ast.args))
        if ast.op == "Or":
            return sympy.Or(*(ConditionProcessor.claripy_ast_to_sympy_expr(arg, memo=memo) for arg in ast.args))
        if ast.op == "Not":
            return sympy.Not(ConditionProcessor.claripy_ast_to_sympy_expr(ast.args[0], memo=memo))

        if ast.op in _UNIFIABLE_COMPARISONS:
            # unify comparisons to enable more simplification opportunities without going "deep" in sympy
            inverse_op = getattr(ast.args[0], _INVERSE_OPERATIONS[ast.op])
            return sympy.Not(ConditionProcessor.claripy_ast_to_sympy_expr(inverse_op(ast.args[1]), memo=memo))

        # Everything else is opaque to sympy and becomes a symbol named after the AST's hash.
        # The memo maps symbols back to ASTs; it is keyed by sympy symbols, so it must not be queried with a claripy
        # AST. Sympy symbols with equal names compare equal, so registering the same AST again just rewrites the
        # same entry.
        symbol = sympy.Symbol(str(hash(ast)))
        if memo is not None:
            memo[symbol] = ast
        return symbol

    @staticmethod
    def sympy_expr_to_claripy_ast(expr, memo: dict):
        """
        Convert a sympy boolean expression, as produced by ``claripy_ast_to_sympy_expr``, back into a claripy AST.

        :param expr:    The sympy expression: symbols, And/Or/Not, and the boolean constants.
        :param memo:    The Symbol -> claripy AST mapping filled in by ``claripy_ast_to_sympy_expr``.
        :return:        The equivalent claripy AST.
        :raises AngrRuntimeError: If ``expr`` contains a sympy node that is not handled here.
        """
        if expr.is_Symbol:
            return memo[expr]

        # n-ary connectives are rebuilt by converting every operand
        for sympy_connective, claripy_connective in ((sympy.Or, claripy.Or), (sympy.And, claripy.And)):
            if isinstance(expr, sympy_connective):
                return claripy_connective(
                    *(ConditionProcessor.sympy_expr_to_claripy_ast(operand, memo) for operand in expr.args)
                )

        if isinstance(expr, sympy.Not):
            negated = ConditionProcessor.sympy_expr_to_claripy_ast(expr.args[0], memo)
            return claripy.Not(negated)

        boolalg = sympy.logic.boolalg
        if isinstance(expr, boolalg.BooleanTrue):
            return claripy.true()
        if isinstance(expr, boolalg.BooleanFalse):
            return claripy.false()
        raise AngrRuntimeError("Unreachable reached")

    @staticmethod
    def simplify_condition(cond, depth_limit=8, variables_limit=8):
        """
        Simplify a claripy boolean condition with sympy's ``simplify_logic`` (non-deep mode).

        Conditions that are deeper than ``depth_limit`` or involve more than ``variables_limit`` variables are
        returned untouched.

        :param cond:            The claripy condition.
        :param depth_limit:     Maximum AST depth for which simplification is attempted.
        :param variables_limit: Maximum number of variables for which simplification is attempted.
        :return:                The simplified condition, or ``cond`` itself when a limit is exceeded.
        """
        symbol_to_ast = {}
        too_complex = cond.depth > depth_limit or len(cond.variables) > variables_limit
        if too_complex:
            return cond
        as_sympy = ConditionProcessor.claripy_ast_to_sympy_expr(cond, memo=symbol_to_ast)
        simplified = sympy.simplify_logic(as_sympy, deep=False)
        return ConditionProcessor.sympy_expr_to_claripy_ast(simplified, symbol_to_ast)

    @staticmethod
    def simplify_condition_deprecated(cond):
        """
        Legacy condition simplifier based on hand-written rewriting passes.

        Z3's simplification may produce weird, unreadable conditions, so its result is only used when it is
        concrete. Otherwise a fixed sequence of rewriting passes is applied. A pass that returns None leaves the
        condition unchanged.

        :param cond:    The claripy condition.
        :return:        The simplified condition.
        """
        z3_result = claripy.simplify(cond)
        if not z3_result.symbolic:
            return z3_result

        rewriting_passes = (
            ConditionProcessor._fold_double_negations,
            ConditionProcessor._revert_short_circuit_conditions,
            ConditionProcessor._extract_common_subexpressions,
            # ConditionProcessor._remove_redundant_terms is unfinished and deliberately not applied
        )
        for rewrite in rewriting_passes:
            rewritten = rewrite(cond)
            if rewritten is not None:
                cond = rewritten

        # finally, let claripy handle really easy cases again
        rewritten = ConditionProcessor._simplify_trivial_cases(cond)
        return cond if rewritten is None else rewritten

    @staticmethod
    def _simplify_trivial_cases(cond):
        """
        Drop the conjuncts of a top-level And that claripy simplifies to true.

        :param cond:    A claripy boolean AST.
        :return:        The reduced conjunction, or ``claripy.true()`` when every conjunct is trivially true.
                        Returns None if ``cond`` is not an And.
        """
        if cond.op != "And":
            return None

        remaining = [arg for arg in cond.args if not claripy.is_true(claripy.simplify(arg))]
        if not remaining:
            # an empty conjunction is true; do not call claripy.And() without any operands
            return claripy.true()
        return claripy.And(*remaining)

    @staticmethod
    def _revert_short_circuit_conditions(cond):
        """
        Revert short-circuit style conditions: ``!A || (A && !B)`` ==> ``!(A && B)``.

        Only the first two operands of a top-level Or are inspected; any further operands are kept as they are.
        ``A`` is searched among the first two conjuncts of the And operand. All remaining conjuncts together form
        ``!B``, e.g. ``!A || (A && R1 && R2)`` ==> ``!(A && !(R1 && R2))``.

        :param cond:    A claripy boolean AST.
        :return:        The rewritten condition, or ``cond`` itself when the pattern does not apply. A one-operand
                        Or is unwrapped.
        """
        if cond.op != "Or":
            return cond

        if len(cond.args) == 1:
            # redundant operator. get rid of it
            return cond.args[0]

        or_arg0, or_arg1 = cond.args[:2]
        if or_arg1.op == "And":
            pass
        elif or_arg0.op == "And":
            or_arg0, or_arg1 = or_arg1, or_arg0
        else:
            return cond

        not_a = or_arg0
        and_args = or_arg1.args
        if len(and_args) < 2:
            return cond

        # the conjunct that may be A must at least involve the same variables as !A
        if not_a.variables == and_args[0].variables:
            a_idx = 0
        elif not_a.variables == and_args[1].variables:
            a_idx = 1
        else:
            return cond

        solver = claripy.SolverCacheless()
        solver.add(not_a == and_args[a_idx])
        if solver.satisfiable():
            # the conjunct can equal !A, so it is not A
            return cond

        # found it! every other conjunct belongs to !B - none may be dropped
        rest = and_args[:a_idx] + and_args[a_idx + 1 :]
        not_b = rest[0] if len(rest) == 1 else claripy.And(*rest)
        b = claripy.Not(not_b)
        a = claripy.Not(not_a)
        if len(cond.args) <= 2:
            return claripy.Not(claripy.And(a, b))
        return claripy.Or(claripy.Not(claripy.And(a, b)), *cond.args[2:])

    @staticmethod
    def _fold_double_negations(cond):
        """
        Push a top-level negation inwards to get rid of double negations:

            !(!A)           ==> A
            !((!A) && (!B)) ==> A || B
            !((!A) && B)    ==> A || !B
            !(A && (!B))    ==> !A || B
            !(A || B)       ==> (!A && !B)

        Only two-operand And/Or nodes are handled.

        :param cond:    A claripy boolean AST.
        :return:        The rewritten condition, or None if no pattern applies.
        """
        if cond.op != "Not":
            return None
        inner = cond.args[0]

        if inner.op == "Not":
            # strip both negations; returning `inner` itself would yield !A instead of A
            return inner.args[0]

        if inner.op == "And" and len(inner.args) == 2:
            and_0, and_1 = inner.args
            if and_0.op == "Not" and and_1.op == "Not":
                return claripy.Or(and_0.args[0], and_1.args[0])
            if and_0.op == "Not":  # and_1.op != "Not"
                return claripy.Or(and_0.args[0], ConditionProcessor.simplify_condition(claripy.Not(and_1)))
            if and_1.op == "Not":  # and_0.op != "Not"
                return claripy.Or(ConditionProcessor.simplify_condition(claripy.Not(and_0)), and_1.args[0])

        if inner.op == "Or" and len(inner.args) == 2:
            or_0, or_1 = inner.args
            return claripy.And(
                ConditionProcessor.simplify_condition(claripy.Not(or_0)),
                ConditionProcessor.simplify_condition(claripy.Not(or_1)),
            )

        return None

    @staticmethod
    def _extract_common_subexpressions(cond):
        """
        Factor out terms shared by every operand of a disjunction: ``(A && B) || (A && C)`` ==> ``A && (B || C)``.

        Children of And/Or nodes are processed recursively first. An operand may consist solely of common terms;
        then the leftover disjunction is true and is absorbed, e.g. ``(A && B) || A`` ==> ``A``.

        :param cond:    A claripy boolean AST.
        :return:        The rewritten condition. Returns None for an And whose operands were all left unchanged, and
                        for nodes that are neither And nor Or.
        """

        def _expr_inside_collection(expr_, coll_) -> bool:
            # identity-based membership test; == on claripy ASTs builds a new AST instead of comparing
            return any(expr_ is ex_ for ex_ in coll_)

        if cond.op == "And":
            args = [ConditionProcessor._extract_common_subexpressions(arg) for arg in cond.args]
            if all(arg is None for arg in args):
                return None
            return claripy.And(*((arg if arg is not None else ori_arg) for arg, ori_arg in zip(args, cond.args)))

        if cond.op == "Or":
            args = [ConditionProcessor._extract_common_subexpressions(arg) for arg in cond.args]
            args = [(arg if arg is not None else ori_arg) for arg, ori_arg in zip(args, cond.args)]

            # count in how many operands each term occurs (a non-And operand is a single term)
            expr_ctrs = defaultdict(int)
            for arg in args:
                if arg.op == "And":
                    for subexpr in arg.args:
                        expr_ctrs[subexpr] += 1
                else:
                    expr_ctrs[arg] += 1

            common_exprs = [expr for expr, ctr in expr_ctrs.items() if ctr == len(args)]
            if not common_exprs:
                return claripy.Or(*args)

            new_args = []
            # set when some operand is made up of common terms only: the leftover disjunction is then true
            absorbed = False
            for arg in args:
                if arg.op == "And":
                    new_subexprs = [
                        subexpr for subexpr in arg.args if not _expr_inside_collection(subexpr, common_exprs)
                    ]
                    if new_subexprs:
                        new_args.append(claripy.And(*new_subexprs))
                    else:
                        absorbed = True
                elif arg in common_exprs:
                    absorbed = True
                else:
                    raise AngrRuntimeError("Unexpected behavior - you should never reach here")

            if absorbed:
                # common && (... || true) == common
                return common_exprs[0] if len(common_exprs) == 1 else claripy.And(*common_exprs)
            return claripy.And(*common_exprs, claripy.Or(*new_args))

        return None

    @staticmethod
    def _extract_terms(ast: claripy.ast.Bool) -> Generator[claripy.ast.Bool]:
        """
        Yield the leaf terms of a boolean AST, looking through And, Or and Not nodes.

        Terms are produced depth-first, left to right. Duplicates are yielded as often as they occur.
        """
        # explicit stack; children are pushed in reverse so the leftmost one is visited first
        pending = [ast]
        while pending:
            node = pending.pop()
            if node.op in ("And", "Or"):
                pending.extend(reversed(node.args))
            elif node.op == "Not":
                pending.append(node.args[0])
            else:
                yield node

    @staticmethod
    def _replace_term_in_ast(
        ast: claripy.ast.Bool,
        r0: claripy.ast.Bool,
        r0_with: claripy.ast.Bool,
        r1: claripy.ast.Bool,
        r1_with: claripy.ast.Bool,
    ) -> claripy.ast.Bool:
        """
        Rebuild a boolean AST with the leaf ``r0`` replaced by ``r0_with`` and the leaf ``r1`` replaced by ``r1_with``.

        And, Or and Not nodes are reconstructed around their rewritten operands. Leaves are matched by identity,
        and ``r0`` takes precedence over ``r1``.
        """

        def _rewrite(node):
            kind = node.op
            if kind == "And":
                return claripy.And(*(_rewrite(child) for child in node.args))
            if kind == "Or":
                return claripy.Or(*(_rewrite(child) for child in node.args))
            if kind == "Not":
                return claripy.Not(_rewrite(node.args[0]))
            if node is r0:
                return r0_with
            return r1_with if node is r1 else node

        return _rewrite(ast)

    @staticmethod
    def _remove_redundant_terms(cond):
        """
        Extract all terms and test for each term if its truism impacts the truism of the entire condition. If not, the
        term is redundant and can be replaced with a True.

        The implementation is unfinished. Redundant terms are only reported through the module logger at DEBUG
        level, ``cond`` is not rewritten, and the method always returns None.
        """
        all_terms = set(ConditionProcessor._extract_terms(cond))

        # pair up terms whose claripy.Not(...) is also a term; only one member of each pair is tested below
        negations = {}
        to_skip = set()
        all_terms_without_negs = set()
        for term in all_terms:
            if term in to_skip:
                continue
            neg = claripy.Not(term)
            if neg in all_terms:
                negations[term] = neg
                to_skip.add(neg)
            all_terms_without_negs.add(term)

        solver = claripy.SolverCacheless()

        def _equivalent_to_cond(other) -> bool:
            # cond <=> other iff neither (cond && !other) nor (!cond && other) is satisfiable
            sat0 = solver.satisfiable(extra_constraints=(cond, claripy.Not(other)))
            sat1 = solver.satisfiable(extra_constraints=(claripy.Not(cond), other))
            return not (sat0 or sat1)

        for term in all_terms_without_negs:
            neg = negations.get(term)

            replaced_with_true = ConditionProcessor._replace_term_in_ast(
                cond, term, claripy.true(), neg, claripy.false()
            )
            if not _equivalent_to_cond(replaced_with_true):
                continue

            replaced_with_false = ConditionProcessor._replace_term_in_ast(
                cond, term, claripy.false(), neg, claripy.true()
            )
            if not _equivalent_to_cond(replaced_with_false):
                continue

            # TODO: Finish the implementation (rewrite cond instead of only reporting the term)
            l.debug("%s is redundant", term)
        return None

    #
    # Graph processing
    #

    @staticmethod
    def _remove_crossing_edges_between_cases(
        graph: networkx.DiGraph, case_entry_to_switch_head: dict[int, int]
    ) -> networkx.DiGraph:
        """
        Remove edges that cross between switch-case bodies, or that lead back to already-traversed nodes.

        The starting nodes are the nodes whose ``addr`` is a key of ``case_entry_to_switch_head``. From each of them
        the graph is walked breadth-first, and an edge (src, succ) is marked for removal when
          - succ has already been traversed and has at least one successor, or
          - succ is one of the starting nodes.
        The set of traversed nodes is shared by all walks. A walk that reaches a node already visited from another
        starting node therefore also marks that edge.

        :param graph:                       The graph to process. It is not modified in place.
        :param case_entry_to_switch_head:   Presumably maps case-entry addresses to switch-head addresses (inferred
                                            from the name). Only its keys are used here.
        :return:                            ``graph`` itself when no edge needs removal; otherwise a copy of
                                            ``graph`` without the marked edges.
        """
        # only the keys of the mapping are consulted
        starting_nodes = {node for node in graph if node.addr in case_entry_to_switch_head}
        if not starting_nodes:
            return graph

        # shared across all walks, so walks from different starting nodes see each other's nodes
        traversed_nodes = set()
        edges_to_remove = set()
        # NOTE(review): starting_nodes is a set, so the walks run in set iteration order. Because traversed_nodes
        # is shared, that order can influence which edges are marked.
        for starting_node in starting_nodes:
            # breadth-first walk; list.pop(0) is O(len(queue))
            queue = [starting_node]
            while queue:
                src = queue.pop(0)
                traversed_nodes.add(src)
                successors = graph.successors(src)
                for succ in successors:
                    if succ in traversed_nodes:
                        # we should not traverse this node twice
                        # (edges into traversed nodes without successors are kept)
                        if graph.out_degree(succ) > 0:
                            edges_to_remove.add((src, succ))
                        continue
                    if succ in starting_nodes:
                        # we do not want any jump from one node to a starting node
                        edges_to_remove.add((src, succ))
                        continue
                    # NOTE(review): src was already added at the top of this loop, so this line has no effect.
                    # Nodes are only marked when popped, so a successor reached from several predecessors before it
                    # is popped is enqueued (and processed) more than once. Was `succ` intended? Changing it would
                    # alter which edges get marked - confirm before fixing.
                    traversed_nodes.add(src)
                    queue.append(succ)

        if not edges_to_remove:
            return graph

        # make a copy before modifying the graph
        graph = networkx.DiGraph(graph)
        graph.remove_edges_from(edges_to_remove)
        return graph

    #
    # Utils
    #

    def create_jump_target_var(self, jumptable_head_addr: int):
        """
        Create a symbolic bit-vector that stands for the jump target of a jump table.

        :param jumptable_head_addr: Address of the jump table head; it is embedded in hex into the variable name.
        :return:                    A claripy BVS that is ``self.arch.bits`` bits wide.
        """
        var_name = f"jump_table_{jumptable_head_addr:x}"
        width = self.arch.bits
        return claripy.BVS(var_name, width, explicit_name=True)
